train

cpprb.train(buffer: ReplayBuffer, env, get_action: Callable, update_policy: Callable, *, max_steps: int = int(1e6), max_episodes: Optional[int] = None, batch_size: int = 64, n_warmups: int = 0, after_step: Optional[Callable] = None, done_check: Optional[Callable] = None, obs_update: Optional[Callable] = None, rew_sum: Optional[Callable[[float, Any], float]] = None, episode_callback: Optional[Callable[[int, int, float], Any]] = None, logger=None)

Train an RL policy (model). See the example at the end of this section.

Parameters
  • buffer (ReplayBuffer) – Buffer to be used for training

  • env (gym.Environment compatible) – Environment to learn from

  • get_action (Callable) – Callable taking obs and returning action

  • update_policy (Callable) – Callable taking sample, step, and episode, updating the policy, and returning the absolute TD error (|TD|).

  • max_steps (int (optional)) – Maximum number of steps to learn. The default value is 1000000

  • max_episodes (int (optional)) – Maximum number of episodes to learn. The default value is None

  • batch_size (int (optional)) – Batch size used when sampling from the buffer. The default value is 64

  • n_warmups (int (optional)) – Warmup steps before sampling. The default value is 0 (No warmup)

  • after_step (Callable (optional)) – Callable converting obs, the return values of env.step(action), step, and episode into a dict of a transition passed to ReplayBuffer.add. This function can also be used as a per-step summary callback.

  • done_check (Callable (optional)) – Callable checking whether the episode is done

  • obs_update (Callable (optional)) – Callable updating obs

  • rew_sum (Callable[[float, Dict], float] (optional)) – Callable summarizing episode reward

  • episode_callback (Callable[[int, int, float], Any] (optional)) – Callable for episode summarization

  • logger (logging.Logger (optional)) – Custom Logger

Raises

ValueError – When max_steps is larger than the size_t limit

Warning

cpprb.train is still a beta release. The API may change.
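
Example

A minimal usage sketch is shown below. The environment name, buffer keys, and the bodies of get_action and update_policy are illustrative assumptions, not part of cpprb; adapt them (or supply after_step) to match your own setup.

  import gym
  import numpy as np
  import cpprb
  from cpprb import ReplayBuffer

  env = gym.make("CartPole-v1")
  obs_shape = env.observation_space.shape

  # Buffer keys are assumptions; adjust them (or pass after_step)
  # to match the transition dict used in your training loop.
  buffer = ReplayBuffer(int(1e5),
                        {"obs": {"shape": obs_shape},
                         "act": {"shape": 1},
                         "rew": {},
                         "next_obs": {"shape": obs_shape},
                         "done": {}})

  def get_action(obs):
      # Placeholder policy: random action. Replace with your policy network.
      return env.action_space.sample()

  def update_policy(sample, step, episode):
      # Placeholder update: would normally train on `sample` and return
      # the absolute TD error for each sampled transition.
      return np.zeros_like(sample["rew"])

  cpprb.train(buffer, env, get_action, update_policy,
              max_steps=int(1e4), n_warmups=500, batch_size=32)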